# Performance Above Random Expectation (PARE)

# Load dependencies: dplyr (data-frame filtering), ggplot2 (plotting), and
# the AUC package (ROC curves / AUC values). The '##' lines are console
# output captured from a previous run of this script.
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(ggplot2)
library(AUC)
## AUC 0.3.0
## Type AUCNews() to see the change log and ?AUC to get an overview.
# Output directory for every figure saved by saveFigure().
dir.create("Figures", showWarnings = FALSE)

# Default figure dimensions in inches, used throughout the plotting helpers.
defaultWidth <- 8
defaultHeight <- 4.5
narrowWidth <- 4

saveFigure <- function(description1, description2, width=defaultWidth, height=defaultHeight)
{
  # Save the most recently printed ggplot to
  # Figures/<description1>__<description2>.png, replacing spaces in
  # description1 with underscores so the file name is filesystem-safe.
  # Relies on ggsave() defaulting to the last plot displayed.
  ggsave(paste0("Figures/", gsub(" ", "_", description1), "__", description2, ".png"), width=width, height=height, units="in")
}
generateNumericScores <- function(means, randomSeed=999, n=numRandomValues)
{
  # Generate n standardized scores drawn from normal distributions, one
  # distribution per entry in `means`, with n / length(means) draws from
  # each. A fixed seed makes the draws reproducible.
  # NOTE: `n` defaults to the global `numRandomValues` defined in this script,
  # and the result is rescaled to [0, 1] by standardize().
  if (length(means) < 1)
    stop("Must provide at least one mean.")

  set.seed(randomSeed)

  # Draw each group and concatenate once, instead of growing the vector
  # with c() inside a loop. The rnorm() calls happen in the same order as
  # before, so the RNG stream (and output) is unchanged.
  scores <- unlist(lapply(means, function(groupMean) rnorm(n / length(means), mean=groupMean)))

  return(standardize(scores))
}

generateNumericScoresImbalanced <- function(means, randomSeed=0, n=numRandomValues)
{
  # Generate standardized scores for an imbalanced two-class scenario:
  # imbalanceProportion of the n values come from the first mean and the
  # remainder from the second (both globals defined in this script).
  if (length(means) != 2)
    stop("Must provide exactly two means.")

  set.seed(randomSeed)

  # Majority group first, then minority group, so the RNG stream matches
  # the original implementation for any given seed.
  scores <- c(rnorm(n * imbalanceProportion, mean=means[1]),
              rnorm(n * (1 - imbalanceProportion), mean=means[2]))

  return(standardize(scores))
}

standardize <- function(x)
{
  # Linearly rescale x to the [0, 1] interval: min(x) maps to 0, max(x) to 1.
  # NOTE(review): yields NaN when every element of x is identical.
  rng <- range(x)
  (x - rng[1]) / (rng[2] - rng[1])
}

# Simulation parameters: total number of scores, and the fraction of cases
# in the majority (negative) class for the imbalanced scenarios.
numRandomValues = 1000
imbalanceProportion = 0.95

# Class-label factors: balanced 500/500 and imbalanced 950/50.
actual = factor(c(rep(0, numRandomValues / 2), rep(1, numRandomValues / 2)))
actualImbalanced = factor(c(rep(0, numRandomValues * imbalanceProportion), rep(1, numRandomValues * (1 - imbalanceProportion))))

# Simulated predictor scores with varying amounts of signal; "opposite"
# scores are anti-correlated with the labels.
randomScores = generateNumericScores(0)
mediumSignalScores = generateNumericScores(c(0, 1))
signalScores = generateNumericScores(c(0, 10))
oppositeScores = generateNumericScores(c(10, 0))
# NOTE(review): no explicit set.seed() here -- these draws are only
# reproducible via the seed state left by the preceding call.
aFewGoodScores = standardize(c(rnorm(990), rnorm(10, mean=10)))

mediumSignalScoresImbalanced = generateNumericScoresImbalanced(c(0, 1))
signalScoresImbalanced = generateNumericScoresImbalanced(c(0, 10))
oppositeScoresImbalanced = generateNumericScoresImbalanced(c(10, 0))
plotBarPlot <- function(values, main, width=defaultWidth)
{
  # Draw a bar plot of `values` (e.g. class-label counts), print it to the
  # active device, and save it under Figures/ via saveFigure().
  df <- data.frame(Scores=values)

  barPlot <- ggplot(df, aes(x=Scores)) +
    geom_bar() +
    theme_bw() +
    theme(legend.key = element_blank(), legend.title=element_blank()) +
    xlab("") +
    ylab("") +
    ggtitle(main)
  print(barPlot)

  saveFigure(main, "Barplot", width=width)
}

plotHist <- function(values, main, width=defaultWidth)
{
  # Draw a histogram of `values` (default ggplot2 binning), print it, and
  # save it under Figures/ via saveFigure().
  df <- data.frame(Scores=values)

  histPlot <- ggplot(df, aes(x=Scores)) +
    geom_histogram() +
    theme_bw() +
    theme(legend.key = element_blank(), legend.title=element_blank()) +
    xlab("") +
    ylab("") +
    ggtitle(main)
  print(histPlot)

  saveFigure(main, "Histogram", width=width)
}

plotBoxPlot <- function(scores, actual, main, width=defaultWidth)
{
  # Draw one box of `scores` per class in `actual`, print the plot, and
  # save it under Figures/ via saveFigure().
  df <- data.frame(Scores=scores, Actual=actual)

  boxPlot <- ggplot(df, aes(factor(Actual), Scores)) +
    geom_boxplot() +
    theme_bw() +
    theme(legend.key = element_blank(), legend.title=element_blank()) +
    xlab("") +
    ylab("") +
    ggtitle(main)
  print(boxPlot)

  saveFigure(main, "BoxPlot", width=width)
}

# Bar plots of the class-label distributions.
plotBarPlot(actual, main="Class Labels", width=narrowWidth)

plotBarPlot(actualImbalanced, main="Imbalanced Class Labels", width=narrowWidth)

# Histograms of each simulated score vector. The `stat_bin()` lines are
# ggplot2's captured console messages about the default bin count.
plotHist(randomScores, main="Random Scores")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

plotHist(mediumSignalScores, main="Medium Signal Scores")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

plotHist(signalScores, main="Signal Scores")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

plotHist(oppositeScores, main="Opposite Scores")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

plotHist(aFewGoodScores, main="A Few Good Scores")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

# Per-class box plots of the balanced scenarios.
plotBoxPlot(randomScores, actual, main="Random Scores")

plotBoxPlot(mediumSignalScores, actual, main="Medium Signal Scores")

plotBoxPlot(signalScores, actual, main="Signal Scores")

plotBoxPlot(oppositeScores, actual, main="Opposite Scores")

plotBoxPlot(aFewGoodScores, actual, main="A Few Good Scores", width=narrowWidth)

# Histograms and box plots for the imbalanced scenarios.
plotHist(mediumSignalScoresImbalanced, main="Medium Signal Scores Imbalanced")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

plotHist(signalScoresImbalanced, main="Signal Scores Imbalanced")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

plotHist(oppositeScoresImbalanced, main="Opposite Scores Imbalanced")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

plotBoxPlot(mediumSignalScoresImbalanced, actualImbalanced, main="Medium Signal Scores Imbalanced")

plotBoxPlot(signalScoresImbalanced, actualImbalanced, main="Signal Scores Imbalanced")

plotBoxPlot(oppositeScoresImbalanced, actualImbalanced, main="Opposite Scores Imbalanced")

plotROC <- function(actual, scores, main, width=defaultWidth)
{
  # Plot a ROC curve for `scores` against the factor labels `actual`,
  # annotate the title with the AUC, and save the figure.
  # roc()/auc() come from the AUC package and expect (predictions, labels).
  rocResult <- roc(scores, actual)
  aucValue <- auc(rocResult)

  # Title with the AUC on a second line, e.g. "Random Scores\n(AUC = 0.500)".
  main2 <- paste0(main, "\n", "(AUC = ", format(aucValue, digits=3, nsmall=3), ")")

  plotData <- data.frame(TPR=rocResult$tpr, FPR=rocResult$fpr)
  print(ggplot(plotData, aes(x=FPR, y=TPR)) + geom_abline(color="darkgrey", linetype="dashed") + geom_line() + theme_bw() + theme(legend.key = element_blank(), legend.title=element_blank()) + xlab("False positive rate") + ylab("True positive rate") + ggtitle(main2))
  saveFigure(main, "ROC", width=width)
}

# ROC curves for each simulated score set.
plotROC(actual, randomScores, "Random Scores")

plotROC(actual, mediumSignalScores, main="Medium Signal Scores")

plotROC(actual, signalScores, main="Signal Scores")

plotROC(actual, oppositeScores, main="Opposite Scores")

plotROC(actual, aFewGoodScores, main="A Few Good Scores", width=narrowWidth)

# NOTE(review): this pairs the balanced randomScores vector with the
# imbalanced labels -- presumably intentional (random is random), but verify.
plotROC(actualImbalanced, randomScores, main="Random Scores Imbalanced")

plotROC(actualImbalanced, mediumSignalScoresImbalanced, main="Medium Signal Scores Imbalanced")

plotROC(actualImbalanced, signalScoresImbalanced, main="Signal Scores Imbalanced")

plotROC(actualImbalanced, oppositeScoresImbalanced, main="Opposite Scores Imbalanced")

calculateAccuracy <- function(labels, predictors, threshold)
{
  # Fraction of cases where thresholding the predictor (predictor > threshold
  # => predicted class 1) reproduces the 0/1 numeric label.
  predicted <- as.integer(predictors > threshold)
  mean(predicted == labels)
}

calculateNumCorrect <- function(labels, predictors, threshold)
{
  # Count of cases where thresholding the predictor reproduces the 0/1 label.
  predicted <- as.integer(predictors > threshold)
  sum(predicted == labels)
}

calculateSensitivity <- function(labels, predictors, threshold)
{
  # True-positive rate (recall): among cases with label == 1, the fraction
  # whose predictor exceeds the threshold.
  positiveScores <- predictors[which(labels == 1)]
  sum(positiveScores > threshold) / sum(labels == 1)
}

calculateSpecificity <- function(labels, predictors, threshold)
{
  # True-negative rate: among cases with label == 0, the fraction whose
  # predictor is at or below the threshold.
  negativeScores <- predictors[which(labels == 0)]
  sum(negativeScores <= threshold) / sum(labels == 0)
}

calculatePPV <- function(labels, predictors, threshold)
{
  # Positive predictive value (precision): among cases predicted positive
  # (predictor > threshold), the fraction whose label is 1.
  numPredictedPositive <- sum(predictors > threshold)

  # BUG FIX: the original guard tested length(predictors > threshold), which
  # is the total number of predictions (never 0 for nonempty input), not the
  # number above the threshold -- so PPV was NaN (0/0) whenever nothing
  # exceeded the threshold. Define PPV as 0 in that case.
  if (numPredictedPositive == 0)
    return(0)

  numTruePositive <- length(intersect(which(predictors > threshold), which(labels == 1)))
  numTruePositive / numPredictedPositive
}

calculateNPV <- function(labels, predictors, threshold)
{
  # Negative predictive value: among cases predicted negative
  # (predictor <= threshold), the fraction whose label is 0.
  numPredictedNegative <- sum(predictors <= threshold)

  # BUG FIX: the original guard tested length(predictors <= threshold), which
  # is the total number of predictions, not the number at/below the
  # threshold -- so NPV was NaN (0/0) whenever everything exceeded the
  # threshold. Define NPV as 0 in that case.
  if (numPredictedNegative == 0)
    return(0)

  numTrueNegative <- length(intersect(which(predictors <= threshold), which(labels == 0)))
  numTrueNegative / numPredictedNegative
}

calculateFscore <- function(labels, predictors, threshold)
{
  # F1 score: harmonic mean of precision (PPV) and recall (sensitivity) at
  # the given threshold.
  precision <- calculatePPV(labels, predictors, threshold)
  recall <- calculateSensitivity(labels, predictors, threshold)

  # BUG FIX: when both precision and recall are zero the original returned
  # NaN (0/0); the F-score is conventionally defined as 0 in that case.
  if (precision + recall == 0)
    return(0)

  (2 * (precision * recall)) / (precision + recall)
}

checkPareInputs <- function(labels, predictors)
{
  # Validate inputs for the PARE functions:
  #   labels     -- a factor with exactly the two levels 0 and 1
  #   predictors -- numeric scores in [0, 1], one per label
  # Stops with an informative error on the first violation; returns NULL
  # (invisibly) when everything checks out.
  if (!is.factor(labels))
    stop("The labels must be a factor object.")

  if (nlevels(labels) != 2)
    stop("The labels factor object must have exactly two levels.")

  if (length(intersect(c(0, 1), levels(labels))) != 2)
    stop("The factor levels of the labels object must be c(0, 1).")

  if (!is.numeric(predictors))
    stop("The predictors must be numeric.")

  if (length(predictors) != length(labels))
    stop("The length of the predictors and labels must be the same.")

  if (any(predictors < 0 | predictors > 1))
    stop("Predictions must fall between 0 and 1.")
}

calculateScoresAtThresholds <- function(labels, predictors, metricCalculationFunction=calculateAccuracy, metricLabel="Accuracy", numPermutations=10)
{
  # Evaluate the chosen metric across a grid of thresholds for four scenarios:
  #   Actual   -- the supplied predictors against the supplied labels
  #   Permuted -- the metric averaged over numPermutations label shuffles
  #   Maximal  -- predictors sorted to best agree with the sorted labels
  #   Minimal  -- predictors sorted to maximally disagree with the labels
  # plus the derived per-threshold PARE score. Returns one long data frame
  # with columns ScoreType, Score, and Threshold.
  # NOTE(review): the metricLabel argument is accepted but never used here.
  checkPareInputs(labels, predictors)
  
  # Combine the labels and predictors so we can sort them together
  combined = cbind(as.integer(labels) - 1, predictors)
  combined = combined[order(labels),]

  # Then separate them back out; classLabels are now numeric 0/1 and sorted
  # so that all negatives precede all positives.
  classLabels = combined[,1]
  classPredictors = combined[,2]

  # Evenly spaced threshold grid spanning the predictor range, dropping the
  # two endpoints (where every case falls on one side of the threshold).
  thresholdStepSize = (max(classPredictors) - min(classPredictors)) / (length(classPredictors) + 1)
  thresholds = seq(min(classPredictors), max(classPredictors), thresholdStepSize)
  thresholds = thresholds[2:(length(thresholds)-1)]

  actualResultAtThresholds = calculateMetricAtThresholds(classLabels, classPredictors, thresholds, metricCalculationFunction)
  
  # Random-expectation baseline: average the metric over several seeded
  # label permutations (one row per permutation, one column per threshold).
  permutedMatrix = NULL
  for (i in 1:numPermutations)
  {
    set.seed(i)
    iPermuted = calculateMetricAtThresholds(sample(classLabels), classPredictors, thresholds, metricCalculationFunction)
    permutedMatrix <- rbind(permutedMatrix, iPermuted)
  }
  permutedResultAtThresholds <- apply(permutedMatrix, 2, mean)
  
  # Best case: ascending predictors line up with the sorted labels.
  idealPredictors = sort(classPredictors)
  idealResultAtThresholds = calculateMetricAtThresholds(classLabels, idealPredictors, thresholds, metricCalculationFunction)

  # Worst case: descending predictors oppose the sorted labels.
  wrongPredictors = sort(classPredictors, decreasing=TRUE)
  wrongResultAtThresholds = calculateMetricAtThresholds(classLabels, wrongPredictors, thresholds, metricCalculationFunction)
  
  # Keep only thresholds where the three reference curves are mutually
  # distinguishable; elsewhere the PARE ratio would be degenerate (0/0).
  whichThresholdsToKeep <- which(idealResultAtThresholds != permutedResultAtThresholds & idealResultAtThresholds != wrongResultAtThresholds & wrongResultAtThresholds != permutedResultAtThresholds)
  
  # This indicates that all predictions were the same.
  if (length(whichThresholdsToKeep) == 0) {
    pareScoreAtThresholds <- permutedResultAtThresholds
  } else {
    thresholds <- thresholds[whichThresholdsToKeep]
    actualResultAtThresholds <- actualResultAtThresholds[whichThresholdsToKeep]
    permutedResultAtThresholds <- permutedResultAtThresholds[whichThresholdsToKeep]
    idealResultAtThresholds <- idealResultAtThresholds[whichThresholdsToKeep]
    wrongResultAtThresholds <- wrongResultAtThresholds[whichThresholdsToKeep]
    
    pareScoreAtThresholds <- calculatePareScores(actualResultAtThresholds, permutedResultAtThresholds, idealResultAtThresholds, wrongResultAtThresholds)
  }
  
  # Stack the five curves into long format for plotting with ggplot2.
  thresholdData <- data.frame(ScoreType="Maximal", Score=idealResultAtThresholds, Threshold=thresholds)
  thresholdData <- rbind(thresholdData, data.frame(ScoreType="Actual", Score=actualResultAtThresholds, Threshold=thresholds))
  thresholdData <- rbind(thresholdData, data.frame(ScoreType="Permuted", Score=permutedResultAtThresholds, Threshold=thresholds))
  thresholdData <- rbind(thresholdData, data.frame(ScoreType="Minimal", Score=wrongResultAtThresholds, Threshold=thresholds))
  thresholdData <- rbind(thresholdData, data.frame(ScoreType="PARE", Score=pareScoreAtThresholds, Threshold=thresholds))
  
  return(thresholdData)
}

calculatePareScores <- function(actualResultAtThresholds, permutedResultAtThresholds, idealResultAtThresholds, wrongResultAtThresholds)
{
  # Per-threshold PARE score: the actual metric expressed relative to the
  # permuted (random) baseline.
  #   score = (actual - permuted) / (ideal - permuted)     at/above baseline
  #   score = (actual - permuted) / (permuted - wrong)     below baseline
  # Scores near 1 mean close-to-ideal, 0 means no better than random, and
  # negative values mean worse than random.
  numThresholds <- length(actualResultAtThresholds)

  # Preallocate instead of growing the result with c() inside the loop;
  # seq_len() also makes the empty-input case return numeric(0) rather than
  # erroring (1:0 iterated backwards in the original).
  scoreAtThresholds <- numeric(numThresholds)

  for (i in seq_len(numThresholds))
  {
    actualDiff <- actualResultAtThresholds[i] - permutedResultAtThresholds[i]
    baselineDiff <- idealResultAtThresholds[i] - permutedResultAtThresholds[i]

    score <- 0
    if (baselineDiff != 0)
      score <- actualDiff / baselineDiff

    if (score < 0)
    {
      # Below-random performance is rescaled against the permuted-vs-worst
      # gap. NOTE(review): if permuted == wrong here this divides by zero
      # (yielding +/-Inf), matching the original behavior.
      score <- actualDiff / (permutedResultAtThresholds[i] - wrongResultAtThresholds[i])
    }

    scoreAtThresholds[i] <- score
  }

  return(scoreAtThresholds)
}

calculatePareScore <- function(thresholdData)
{
  # Overall PARE score: the mean of the per-threshold PARE values in the
  # long-format data frame produced by calculateScoresAtThresholds().
  pareRows <- thresholdData[thresholdData$ScoreType == "PARE", ]
  mean(pareRows$Score)
}

getMainForPlot <- function(thresholdData, main)
{
  # Build a plot title that appends the overall PARE score, e.g.
  # "My Title\n(Score = 0.123)"; when main is "" only the score line is used.
  pareScore <- calculatePareScore(thresholdData)

  mainScore <- paste0("(Score = ", format(pareScore, digits=3, nsmall=3), ")")

  if (main == "")
    return(mainScore)

  paste0(main, "\n", mainScore)
}

plotPareScores <- function(thresholdData, main="", width=defaultWidth, height=defaultHeight,
                           metricLabel=get("metricLabel", envir=parent.frame()))
{
  # Plot the Maximal/Actual/Permuted/Minimal metric curves across thresholds
  # (the PARE rows are plotted separately by plotPareThresholdSelection) and
  # save the figure.
  # BUG FIX: metricLabel used to be read implicitly from the global
  # environment; it is now an explicit parameter whose default preserves the
  # old dynamic lookup, so existing callers are unaffected.
  plotData <- filter(thresholdData, ScoreType!="PARE")

  print(ggplot(plotData, aes(x=Threshold, y=Score, group=ScoreType, colour=ScoreType)) + geom_line() + theme_bw() + theme(legend.key = element_blank(), legend.title=element_blank()) + ylab(metricLabel) + ggtitle(getMainForPlot(thresholdData, main)))
  saveFigure(main, paste0("PARE_", metricLabel), width=width, height=height)
}

plotPareThresholdSelection <- function(thresholdData, main="", width=defaultWidth, height=defaultHeight,
                                       metricLabel=get("metricLabel", envir=parent.frame()))
{
  # Plot PARE scores weighted by the actual metric value at each threshold:
  # high weighted values flag thresholds that are both better than random
  # and absolutely accurate.
  # BUG FIX: metricLabel used to be read implicitly from the global
  # environment; it is now an explicit parameter whose default preserves the
  # old dynamic lookup, so existing callers are unaffected.
  pareScores <- filter(thresholdData, ScoreType=="PARE")$Score
  weights <- filter(thresholdData, ScoreType=="Actual")$Score
  weightedPareScores <- pareScores * weights

  plotData <- filter(thresholdData, ScoreType=="PARE")
  plotData$Score <- weightedPareScores

  print(ggplot(plotData, aes(x=Threshold, y=Score)) + geom_line() + theme_bw() + theme(legend.key = element_blank(), legend.title=element_blank()) + ylab("PARE") + ggtitle(getMainForPlot(thresholdData, main)) + ylim(-1, 1) + geom_hline(yintercept=0, color="red", linetype="dashed"))

  saveFigure(main, paste0("ThresholdSelection_", metricLabel), width=width, height=height)
}

pare <- function(labels, predictors, main="", plot=TRUE, metricCalculationFunction=calculateAccuracy, metricLabel="Accuracy", width=defaultWidth, height=defaultHeight)
{
  # Compute the Performance Above Random Expectation (PARE) score for a set
  # of predictions, optionally plotting the threshold curves.
  #   labels     -- two-level factor with levels 0 and 1
  #   predictors -- numeric scores in [0, 1], one per label
  # Returns (invisibly) the overall PARE score.
  checkPareInputs(labels, predictors)
  thresholdData <- calculateScoresAtThresholds(labels, predictors, metricCalculationFunction, metricLabel)

  if (plot)
  {
    plotPareScores(thresholdData, main, width=width, height=height)
    plotPareThresholdSelection(thresholdData, main, width=width, height=height)
  }

  # BUG FIX: the return was commented out, so callers such as
  # compareAUCvPARE() received NULL instead of the score. Returned invisibly
  # so interactive plotting calls do not also print the number.
  return(invisible(calculatePareScore(thresholdData)))
}

calculateMetricAtThresholds <- function(labels, predictors, thresholds, metricCalculationFunction)
{
  # Evaluate the metric once per candidate threshold, returning one numeric
  # value per threshold. vapply() (rather than sapply()) guarantees a
  # numeric vector even when `thresholds` is empty, where sapply() would
  # return an empty list.
  vapply(thresholds, function(threshold) metricCalculationFunction(labels, predictors, threshold), numeric(1))
}

# Select which metric the PARE curves are computed from; uncomment one of
# the alternative pairs below to switch metrics.
metricCalculationFunction <- calculateAccuracy
metricLabel <- "Accuracy"
#metricCalculationFunction <- calculateSensitivity
#metricLabel <- "Sensitivity"
#metricCalculationFunction <- calculateSpecificity
#metricLabel <- "Specificity"
#metricCalculationFunction <- calculatePPV
#metricLabel <- "PPV"
#metricCalculationFunction <- calculateNPV
#metricLabel <- "NPV"
#metricCalculationFunction <- calculateFscore
#metricLabel <- "F-score"

pare(actual, randomScores, main="Random Scores", metricCalculationFunction=metricCalculationFunction, metricLabel=metricLabel)

# Second metric/width configuration for the remaining runs.
metricCalculationFunction <- calculateAccuracy
metricLabel <- "Accuracy"
width <- defaultWidth
#metricCalculationFunction <- calculateSensitivity
#metricLabel <- "Sensitivity"
#width <- narrowWidth
#metricCalculationFunction <- calculateSpecificity
#metricLabel <- "Specificity"
#width <- narrowWidth

# PARE runs for each simulated scenario. NOTE(review): the widths are a mix
# of the `width` variable, narrowWidth, and defaultWidth -- presumably tuned
# per figure; verify before changing.
pare(actual, randomScores, main="Random Scores", metricCalculationFunction=metricCalculationFunction, metricLabel=metricLabel, width=narrowWidth)

pare(actual, mediumSignalScores, main="Medium Signal Scores", metricCalculationFunction=metricCalculationFunction, metricLabel=metricLabel, width=width)

# NOTE(review): titled "Perfect Scores" although it uses signalScores.
pare(actual, signalScores, main="Perfect Scores", metricCalculationFunction=metricCalculationFunction, metricLabel=metricLabel, width=narrowWidth)

pare(actual, oppositeScores, main="Opposite Scores", metricCalculationFunction=metricCalculationFunction, metricLabel=metricLabel, width=width)

pare(actual, aFewGoodScores, main="A Few Good Scores", metricCalculationFunction=metricCalculationFunction, metricLabel=metricLabel, width=defaultWidth)

pare(actualImbalanced, randomScores, main="Random Scores Imbalanced", metricCalculationFunction=metricCalculationFunction, metricLabel=metricLabel, width=width)

pare(actualImbalanced, mediumSignalScoresImbalanced, main="Medium Signal Scores Imbalanced", metricCalculationFunction=metricCalculationFunction, metricLabel=metricLabel, width=narrowWidth)

pare(actualImbalanced, signalScoresImbalanced, main="Signal Scores Imbalanced", metricCalculationFunction=metricCalculationFunction, metricLabel=metricLabel, width=width)

pare(actualImbalanced, oppositeScoresImbalanced, main="Opposite Scores Imbalanced", metricCalculationFunction=metricCalculationFunction, metricLabel=metricLabel, width=width)

compareAUCvPARE = function(meanOptions, iterations=20)
{
  # Compare AUC against the PARE score across many simulated score sets:
  # for each iteration and each candidate group-mean separation, generate
  # two-group scores, compute both measures, and scatter-plot AUC vs PARE.
  # Uses the global `actual` labels defined earlier in this script.
  # NOTE(review): this relies on pare(..., plot=FALSE) returning the PARE
  # score; if pare() still has its return() commented out it returns NULL,
  # so pareScores stays empty -- verify before using.
  aucScores = NULL
  pareScores = NULL

  for (i in 1:iterations)
  {
    message(paste("Iteration ", i, "/", iterations, sep=""))
    for (j in 1:length(meanOptions))
    {
      meanOption = meanOptions[j]
      # Seed i*j makes each (iteration, mean) pair reproducible.
      predictors = generateNumericScores(c(0, meanOption), i*j)

      aucScores = c(aucScores, auc(roc(predictors, actual)))
      pareScores = c(pareScores, pare(actual, predictors, plot=FALSE))
    }
  }

  #hist(aucScores, main=paste("AUC:", format(mean(aucScores), digits=3, nsmall=3)))
  #hist(pareScores, main=paste("PARE:", format(mean(pareScores), digits=3, nsmall=3)))
  
  plotData <- data.frame(AUC=aucScores, PARE=pareScores)
  print(ggplot(plotData, aes(x=AUC, y=PARE)) + geom_point() + theme_bw() + theme(legend.key = element_blank(), legend.title=element_blank()))
  saveFigure("AUC", "PARE")
}

#compareAUCvPARE(rnorm(100), iterations=100)